package sparkSql

import java.lang

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}
import org.junit.{After, Before, Test}

class sparkSqlFunction {
	// Local Spark configuration running with 3 threads.
	val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[3]")

	// Output directory, deleted before every test run.
	var outpath: String = "out"

	val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

	// Provides the String.delete() extension used in init().
	import util.MyPredef._

	/** Removes the output directory before each test. */
	@Before
	def init(): Unit = {
		// Fixed deprecated procedure syntax (`def init() { ... }`): explicit `: Unit =`.
		outpath.delete()
	}

	/** Stops the SparkSession after each test. */
	@After
	def after(): Unit = {
		spark.stop()
	}

	/**
	 * 1. Custom UDF: registers `add_name`, which appends " UDF" to the
	 * `name` column, and applies it in a SQL query over the `user` view.
	 */
	@Test
	def udfTest(): Unit = {
		val df: DataFrame = spark.read.json("in/user.json")
		df.createOrReplaceTempView("user")

		spark.udf.register("add_name", (name: String) => name + " UDF")

		val resDF: DataFrame = spark.sql("select add_name(name), * from user")

		resDF.show()
	}

	/**
	 * 2. Custom weakly-typed UDAF: registers [[AvgUDAFFunction]] as `my_avg`
	 * and averages `age` grouped by `name`.
	 */
	@Test
	def udafTest(): Unit = {
		val df: DataFrame = spark.read.json("in/user.json")
		df.createOrReplaceTempView("user")

		val udaf: AvgUDAFFunction = new AvgUDAFFunction

		spark.udf.register("my_avg", udaf)

		val resDF: DataFrame = spark.sql("select name, my_avg(age) from user group by name")

		resDF.show()

	}

	/**
	 * 3. Custom strongly-typed UDAF: converts the DataFrame to a
	 * Dataset[UserBean] and applies [[AvgClassUDAFFunction]] as a TypedColumn.
	 */
	@Test
	def udafClassTest(): Unit = {
		val df: DataFrame = spark.read.json("in/user.json")
		df.createOrReplaceTempView("user")

		val udaf = new AvgClassUDAFFunction

		// Turn the Aggregator into a column expression named "avg_age".
		val age: TypedColumn[UserBean, Double] = udaf.toColumn.name("avg_age")

		import spark.implicits._
		val ds: Dataset[UserBean] = df.as[UserBean]

		ds.select(age).show()
	}
}


/**
 * Weakly-typed UDAF that computes the average of a Long input column.
 *
 * Buffer layout: slot 0 = running sum (Long), slot 1 = running count (Long).
 */
class AvgUDAFFunction extends UserDefinedAggregateFunction {

	// Input schema: a single Long column (the age).
	override def inputSchema: StructType = {
		new StructType().add("age", LongType)
	}

	// Intermediate buffer schema: running sum and running count.
	override def bufferSchema: StructType = {
		new StructType().add("sum", LongType).add("count", LongType)
	}

	// Result type of the aggregation.
	override def dataType: DataType = DoubleType

	// Same inputs always yield the same output.
	override def deterministic: Boolean = true

	// Buffer initialization: sum = 0, count = 0.
	override def initialize(buffer: MutableAggregationBuffer): Unit = {
		buffer(0) = 0L
		buffer(1) = 0L
	}

	// Fold one input row into the buffer.
	override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
		// Robustness: skip null ages instead of throwing on getLong of a null cell.
		if (!input.isNullAt(0)) {
			// sum
			buffer(0) = buffer.getLong(0) + input.getLong(0)
			// count
			buffer(1) = buffer.getLong(1) + 1
		}
	}

	// Merge partial buffers produced on different partitions/nodes.
	override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
		// sum
		buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
		// count
		buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
	}

	// Final result: sum / count as a Double.
	// BUG FIX: the original computed `(sum / count).toDouble`, which performs
	// Long integer division and truncates the fraction (avg of 20 and 21 gave
	// 20.0, not 20.5). Convert to Double BEFORE dividing. Also drops the
	// deprecated postfix-operator call (`... toDouble`).
	override def evaluate(buffer: Row): Any = {
		buffer.getLong(0).toDouble / buffer.getLong(1)
	}
}

/**
 * Strongly-typed UDAF: input record and aggregation buffer.
 */
// Input record; field names must match the JSON columns read in the tests.
case class UserBean(name: String, age: Long)
// Mutable running buffer for the average: sum of ages and record count.
// NOTE(review): `sum` is an Int while `age` is a Long — `(sum + age).toInt`
// in reduce() can truncate for large totals; consider widening sum to Long.
case class AvgBuffer(var sum: Int, var count: Int)

/**
 * Strongly-typed UDAF computing the average age of [[UserBean]] records,
 * implemented as a Dataset [[Aggregator]] with an [[AvgBuffer]] accumulator.
 */
class AvgClassUDAFFunction extends Aggregator[UserBean, AvgBuffer, Double] {

	/** Fresh, empty accumulator: zero sum, zero count. */
	override def zero: AvgBuffer = AvgBuffer(0, 0)

	/** Folds a single record into the accumulator (within one partition). */
	override def reduce(acc: AvgBuffer, user: UserBean): AvgBuffer = {
		// Int buffer, Long age: the sum is narrowed back to Int, as defined
		// by AvgBuffer's field types.
		acc.sum = (acc.sum + user.age).toInt
		acc.count += 1
		acc
	}

	/** Combines two partial accumulators (across partitions). */
	override def merge(left: AvgBuffer, right: AvgBuffer): AvgBuffer = {
		left.count += right.count
		left.sum += right.sum
		left
	}

	/** Final result: average as a Double (floating-point division). */
	override def finish(acc: AvgBuffer): Double = acc.sum.toDouble / acc.count

	/** Encoder for the intermediate buffer (a Scala product type). */
	override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

	/** Encoder for the Double result. */
	override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}